问题发生在我使用Keras
的Resnet50
预训练模型时。我希望用Resnet50
作为一个主干网络,通过在后面添加全连接层来实现我的回归任务。训练的时候效果不错,loss
经过很多次训练后收敛了,下降到了一个较低的值。但是在测试阶段,预测结果很差,loss
很高,即使我使用了训练数据去预测。最后发现问题出现在了BN(Batch Normalization)
层。
问题介绍
首先需要说明的是,不同Keras版本的解决方法是不一样的。我的版本如下:
tensorflow = 2.2.0
由于tensorflow2.0已经集成了Keras,所以我使用的是tensorflow里的Keras,也就是这样:
网上关于BN层解决方案有些还是tensorflow1.x
时代的,因此在尝试的时候应该先确认你的版本。这里有一些我搜索的时候比较经典的几篇博客:
- https://zhuanlan.zhihu.com/p/56225304
- https://github.com/keras-team/keras/pull/9965
- http://blog.datumbox.com/the-batch-normalization-layer-of-keras-is-broken/#comment-22015
- https://stackoverflow.com/questions/47157526/resnet-100-accuracy-during-training-but-33-prediction-accuracy-with-the-same
注意他们的版本和时间,离我写这篇博客的时候已经一两年了,至少对于我的版本已经不可用了。
接下来我们简单看一下问题:
理论上来说,evaluate
出的结果应该和刚刚训练的结果一致。训练结束后,网络的参数都固定了,这时候把刚刚的训练数据放进去测试,结果应该和最后一次训练结果一致,可是结果大相径庭。问题就在于Resnet50
中的BN
层。
什么是BN层
BN
的基本思想:因为深层神经网络在做非线性变换前的激活输入值(就是那个y=wx+b
,x
是输入)随着网络深度加深或者在训练过程中,其分布逐渐发生偏移或者变动,之所以训练收敛慢,一般是整体分布逐渐往非线性函数的取值区间的上下限两端靠近,所以这导致反向传播时低层神经网络的梯度消失,这是训练深层神经网络收敛越来越慢的本质原因,而BN
就是通过一定的规范化手段,把每层神经网络任意神经元这个输入值的分布强行拉回到均值为0方差为1的标准正态分布,其实就是把越来越偏的分布强制拉回比较标准的分布,这样使得激活输入值落在非线性函数对输入比较敏感的区域,这样输入的小变化就会导致损失函数较大的变化,意思是这样让梯度变大,避免梯度消失问题产生,而且梯度变大意味着学习收敛速度快,能大大加快训练速度。
BN
在2014年由Loffe和Szegedy提出,它长这个样子:
就是把输入变成一个均值为0,方差为1的正态分布,再进行一个仿射变换(可以想象把正态分布图进行拉伸)。从名字中的Batch
就可以看出,这里的均值和方差都是对于一个batch
而言的。那么问题来了,训练时候好说,就用当前batch
来计算好了,测试的时候怎么办,如果测试的时候只有一个sample
呢?,是补成一个batch
还是怎么做呢?Keras
里是用移动均值和方差,也就是使用历史数据来计算。 所以问题就出在这儿,训练时和测试时,BN
层执行的操作是不一样的,因此训练和测试结果不一致。
解决方案
这里不讨论Keras
的历史问题,比如某个版本冻结了BN
层和没冻结一样等。我们只考虑当前版本如何处理。现在问题就是训练时和测试时执行的操作不一致,那么我们要么让测试去贴合训练,要么就训练阶段贴合测试阶段。
测试端修改
Keras
里是有一个变量learning_phase
来控制当前是训练还是测试模式的,理论上,我们可以在测试前,强制把模式设置为训练模式,这样测试时不就会按照训练阶段执行了吗?但是我试了没有用,我猜测是不管你外面怎么修改模式,它进行测试的时候都会改成测试模式。果然,在predict
的源码里,调用了下面这个函数:
可以看到,返回的参数training
是写死的False
,把它改成True
问题就迎刃而解。但是直接这样改不是万全之策,极有可能造成其他问题。当然你可以加一个判断,根据当前的learning_phase
值来决定training
的值。不过直接改包的源码毕竟不太推荐,所以我们可以使用第二种方法:修改训练端。
训练端修改
如果我们让训练阶段执行和测试阶段同样的操作,也是能解决问题的。TF
为后端时,BN
有一个参数是training
,控制归一化时用的是当前Batch
的均值和方差(训练模式)还是移动均值和方差(测试模式)。那我们只要把Resnet
的所有BN
层的training
都修改为测试模式即可。
实验一下就会发现,问题完美解决。